import os, sys, time
import shutil
import yaml

base = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(base, '../'))

import argparse
import chainer
import numpy as np
import threading
from PIL import Image
from chainer import training
import chainer.functions as F
from chainer import Variable
from chainer.training import extension
from chainer.training import extensions

import source.yaml_utils as yaml_utils
from source.miscs.random_samples import sample_continuous, sample_categorical
from evaluation import load_inception_model
from source.inception.inception_score import inception_score
# from source.inception.inception_score_tf import get_inception_score as inception_score_tf
# from source.inception.inception_score_tf import get_mean_and_cov as get_mean_cov_tf
from evaluation import get_mean_cov as get_mean_cov_chainer
from evaluation import FID

# Module-level switch read by eval_inception: when True, images are drawn
# directly from the generator (baseline_gen_images) instead of being
# Langevin-refined, and no latent snapshots are saved.
_RUN_BASELINE = False

def load_models(config):
    """Build the generator and the binary discriminator from ``config.models``.

    Each entry ('generator', 'binary_discriminator') provides the ``fn``,
    ``name`` and ``args`` handed to ``yaml_utils.load_model``.

    Returns:
        (gen, dis) tuple of freshly constructed models.
    """
    def _build(key):
        spec = config.models[key]
        return yaml_utils.load_model(spec['fn'], spec['name'], spec['args'])

    # Tuple elements are evaluated left to right: generator first, as before.
    return _build('generator'), _build('binary_discriminator')

def e_grad(z, P, gen, dis, alpha, ret_e=False):
    """Gradient of the latent energy with respect to ``z``.

    Per sample the energy is

        E(z) = -sum_j log P(z)_j - alpha * dis(gen(z))

    i.e. the negative prior log-density (``P.log_prob`` summed over axis 1)
    minus the ``alpha``-weighted discriminator output.

    Args:
        z: latent Variable the gradient is taken with respect to.
        P: prior distribution exposing ``log_prob``.
        gen: generator, called as ``gen(batchsize=..., z=...)``.
        dis: discriminator applied to the generated images.
        alpha: weight on the discriminator term.
        ret_e: when True, also return the energy Variable.

    Returns:
        The tuple from ``chainer.grad`` (its single element is dE/dz), or
        ``(E, that tuple)`` when ``ret_e`` is True.
    """
    n = z.shape[0]
    log_prior = F.sum(P.log_prob(z), 1, keepdims=True)
    critic = dis(gen(batchsize=n, z=z))
    energy = -log_prior - alpha * critic
    # No grad_outputs given: chainer.grad seeds every element of E with ones.
    gradients = chainer.grad((energy,), (z,))
    if not ret_e:
        return gradients
    return energy, gradients

def langevin_dynamics(z, gen, dis, alpha, n_steps, step_lr, eps_std, save_every=10):
    """Refine latents ``z`` with ``n_steps`` of noisy gradient descent on the energy.

    Each step draws ``eps = eps_std * N(0, I)`` and applies

        z <- z - step_lr * dE/dz + eps

    where E is the energy computed by ``e_grad`` with a standard-normal prior
    and the ``alpha``-weighted discriminator term.

    Args:
        z: Variable of shape (batch_size, z_dim) with the starting latents.
        gen: generator; ``gen.xp`` supplies the array module and
            ``gen.distribution`` must be "normal".
        dis: discriminator scoring generated images.
        alpha: weight on the discriminator term of the energy.
        n_steps: number of update steps.
        step_lr: gradient step size.
        eps_std: standard deviation of the injected Gaussian noise.
        save_every: record z every ``save_every`` steps. The default 10 is what
            the ``z_step`` printouts elsewhere in this file assume.

    Returns:
        List of Variables: z before steps 0, save_every, 2*save_every, ...,
        followed by the final z.

    Raises:
        NotImplementedError: if ``gen.distribution`` is not "normal".
    """
    xp = gen.xp
    batch_size, z_dim = z.shape
    if gen.distribution != "normal":
        raise NotImplementedError(gen.distribution)
    prior = chainer.distributions.Normal(xp.zeros((z_dim, ), dtype=xp.float32),
                                         xp.ones((z_dim, ), dtype=xp.float32))

    z_sp = []
    for step in range(n_steps):
        if step % save_every == 0:
            z_sp.append(z)
        # Noise is drawn before the gradient, keeping the original RNG order.
        eps = eps_std * xp.random.randn(batch_size, z_dim).astype(xp.float32)
        grad = e_grad(z, prior, gen, dis, alpha)
        # Restart from a fresh leaf Variable. Updating z with Variable
        # arithmetic linked every step onto the previous one, so the graph kept
        # growing with the step count and each backward pass could traverse
        # the whole history. The numerical update is unchanged.
        z = Variable(z.array - step_lr * grad[0].array + eps)
    z_sp.append(z)
    print(n_steps, len(z_sp), z.shape)
    return z_sp


def langevin_sample(gen, dis, config, n=50000, batchsize=100):
    """Generate ``n`` images by Langevin-refining generator latents batch by batch.

    ``alpha``, ``n_steps``, ``step_lr`` and ``eps_std`` are read from
    ``config.langevin``. For each batch, latents from ``gen.sample_z`` are
    refined with ``langevin_dynamics``, then the final latents are rendered by
    ``gen``. Everything runs with ``train=False``.

    Returns:
        ims: uint8 array of ``n`` images, rescaled via x * 127.5 + 127.5 and
            clipped to [0, 255]. Assumes gen returns (batchsize, 3, h, w)
            arrays, so the result is (n, 3, h, w).
        zs: array of shape (n_batches, n_snapshots, batchsize, z_dim) holding
            every latent snapshot of every batch.

    If ``n`` is not a multiple of ``batchsize``, the last batch is still full
    size. ``ims`` is truncated to ``n``, while ``zs`` keeps the whole batch.
    """
    alpha = config.langevin['alpha']
    n_steps = config.langevin['n_steps']
    step_lr = config.langevin['step_lr']
    eps_std = config.langevin['eps_std']
    ims = []
    zs = []
    for i in range(0, n, batchsize):
        with chainer.using_config('train', False):
            z = Variable(gen.sample_z(batchsize))
            z_sp = langevin_dynamics(z, gen, dis, alpha, n_steps, step_lr, eps_std)
            # The final render used to sit outside this block, i.e. it ran in
            # train mode, unlike every Langevin step and baseline_gen_images.
            # It needs no gradient either.
            with chainer.using_config('enable_backprop', False):
                x = gen(batchsize, z_sp[-1])
        x = chainer.cuda.to_cpu(x.data)
        ims.append(np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8))
        zs.append(np.stack([chainer.cuda.to_cpu(o.data) for o in z_sp], axis=0))
        if i % 50 == 0:
            print(i)
    # Concatenate along the batch axis; the [:n] guards a partial last batch,
    # which made the old reshape to (n, 3, h, w) fail.
    return np.concatenate(ims, axis=0)[:n], np.stack(zs, axis=0)

def parallel_apply(modules, config, n_list, devices, batchsize=100):
    """Run ``langevin_sample`` for every (gen, dis) pair, one thread per device.

    ``modules[k]`` is a ``(gen, dis)`` pair that generates ``n_list[k]`` images
    inside ``chainer.using_device(devices[k])``. With a single pair, the work
    runs in the calling thread. Exceptions raised by workers are captured and
    re-raised here, checked in ``modules`` order.

    Returns:
        (list of image arrays, list of latent arrays), both in ``modules`` order.
    """
    guard = threading.Lock()
    outcome = {}

    def _job(slot, pair, count, dev):
        try:
            with chainer.using_device(dev):
                g, d = pair
                images, latents = langevin_sample(g, d, config, count, batchsize)
            with guard:
                outcome[slot] = (images, latents)
        except Exception as exc:
            with guard:
                outcome[slot] = exc

    if len(modules) <= 1:
        _job(0, modules[0], n_list[0], devices[0])
    else:
        workers = []
        for slot, (pair, count, dev) in enumerate(zip(modules, n_list, devices)):
            workers.append(threading.Thread(target=_job, args=(slot, pair, count, dev)))
        for w in workers:
            w.start()
        for w in workers:
            w.join()

    images_out, latents_out = [], []
    for slot in range(len(modules)):
        result = outcome[slot]
        if isinstance(result, Exception):
            raise result
        images, latents = result
        images_out.append(images)
        latents_out.append(latents)
    return images_out, latents_out

def langevin_sample_multigpu(gen, dis, config, gpu_list, n=50000, batchsize=100):
    """Split Langevin sampling of ``n`` images across the GPUs in ``gpu_list``.

    Work is divided into whole batches of ``batchsize``; the remainder
    ``n % batchsize`` is not generated. Leftover batches go one per GPU, and
    GPUs that would get no batch are skipped, because ``langevin_sample``
    cannot produce 0 images. Each used GPU gets its own ``copy()`` of ``gen``
    and ``dis`` moved to that GPU.

    Returns:
        (ims, zs): per-GPU results concatenated along axis 0, in ``gpu_list`` order.

    Raises:
        ValueError: if ``n < batchsize`` (no full batch to generate).
    """
    n_batches = n // batchsize
    base, extra = divmod(n_batches, len(gpu_list))

    devices = []
    n_list = []
    for rank, gpu_id in enumerate(gpu_list):
        count = base + (1 if rank < extra else 0)
        if count > 0:
            devices.append(gpu_id)
            n_list.append(count * batchsize)
    if not devices:
        raise ValueError('n ({}) must be at least batchsize ({})'.format(n, batchsize))

    modules = []
    for gpu_id in devices:
        gen_copy = gen.copy()
        gen_copy.to_gpu(gpu_id)
        dis_copy = dis.copy()
        dis_copy.to_gpu(gpu_id)
        modules.append((gen_copy, dis_copy))
    print('n_list', n_list)

    ims, zs = parallel_apply(modules, config, n_list, devices, batchsize)
    return np.vstack(ims), np.concatenate(zs, axis=0)

def langevin_sample_vis(gen, dis, config, dst, rows=10, cols=10, seed=0):
    """Build a trainer extension that saves a rows x cols grid of Langevin samples.

    On each call the extension seeds NumPy's RNG with ``seed``, generates
    ``rows * cols`` images with ``langevin_sample`` in one batch, tiles them
    into a single RGB image and writes it to
    ``{dst}/preview/image{iteration:08}.png``.
    """
    @chainer.training.make_extension()
    def make_image(trainer):
        np.random.seed(seed)
        n_images = rows * cols
        # langevin_sample returns (images, latent snapshots). The old code read
        # .shape on the tuple itself and always raised.
        x, _ = langevin_sample(gen, dis, config, n_images, batchsize=n_images)
        _, _, h, w = x.shape
        x = x.reshape((rows, cols, 3, h, w))
        x = x.transpose(0, 3, 1, 4, 2)  # -> (rows, h, cols, w, 3)
        x = x.reshape((rows * h, cols * w, 3))
        preview_dir = '{}/preview'.format(dst)
        preview_path = preview_dir + '/image{:0>8}.png'.format(trainer.updater.iteration)
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)
        Image.fromarray(x).save(preview_path)

    return make_image


def baseline_gen_images(gen, n=50000, batchsize=100):
    """Draw ``n`` images straight from the generator, without Langevin refinement.

    Batches of ``batchsize`` are generated with ``train=False`` and backprop
    disabled. They are mapped to uint8 via x * 127.5 + 127.5 and clipped to
    [0, 255], which presumes generator outputs lie roughly in [-1, 1].

    Returns:
        uint8 array of ``n`` images. Assumes gen returns (batchsize, 3, h, w)
        arrays, so the result is (n, 3, h, w). A partial last batch is
        truncated to ``n``; the old reshape raised in that case.
    """
    ims = []
    for _ in range(0, n, batchsize):
        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
            x = gen(batchsize)
        x = chainer.cuda.to_cpu(x.data)
        ims.append(np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8))
    return np.concatenate(ims, axis=0)[:n]

def calc_inception_score(ims, inception_model, splits, dst=None, exp_name=''):
    """Compute the Inception Score of ``ims`` and optionally persist it.

    Args:
        ims: images handed unchanged to ``inception_score``.
        inception_model: model passed to ``inception_score``.
        splits: number of splits for the score's mean/std estimate.
        dst: if given, ``[mean, std]`` is written with ``np.savetxt`` to
            ``{dst}/stats/inception_score_{exp_name}.txt``.
        exp_name: suffix for the stats file name.

    Returns:
        dict with keys 'inception_mean' and 'inception_std'.
    """
    mean, std = inception_score(inception_model, ims, splits=splits)
    if dst is not None:
        stats_dir = '{}/stats'.format(dst)
        if not os.path.exists(stats_dir):
            os.makedirs(stats_dir)
        out_file = stats_dir + '/inception_score_{}.txt'.format(exp_name)
        np.savetxt(out_file, np.array([mean, std]))
    return dict(inception_mean=mean, inception_std=std)

def eval_inception(gen, dis, config, n_images, dst, gpu_list, splits=10, batchsize=250, path=None, exp_name=''):
    """Compute the Inception Score of ``n_images`` generated images.

    If the module flag ``_RUN_BASELINE`` is set, images come straight from the
    generator on ``gpu_list[0]``. Otherwise they are Langevin-refined across
    ``gpu_list``. In that case, when ``dst`` is given, the latent snapshots are
    saved via ``np.save`` to ``{dst}/latents/{exp_name}_latent_samples``
    (NumPy appends ``.npy``), regrouped to (n_snapshots, n_samples, z_dim).

    The score file is written under ``{dst}/stats`` with ``exp_name`` as
    suffix. Previously ``exp_name`` was not forwarded, so every run wrote
    ``inception_score_.txt``.

    Returns:
        dict with 'inception_mean' and 'inception_std'.
    """
    if _RUN_BASELINE:
        gen.to_gpu(gpu_list[0])
        ims = baseline_gen_images(gen, n_images, batchsize=batchsize).astype("f")
    else:
        ims, zs = langevin_sample_multigpu(gen, dis, config, gpu_list, n_images, batchsize=batchsize)
        ims = ims.astype('f')
        if dst is not None:
            # zs is (n_batches, n_snapshots, batchsize, z_dim); put snapshots
            # first and flatten batches into one sample axis.
            zs = np.reshape(np.transpose(zs, axes=(1, 0, 2, 3)), (zs.shape[1], -1, zs.shape[-1]))
            latents_dir = '{}/latents'.format(dst)
            if not os.path.exists(latents_dir):
                os.makedirs(latents_dir)
            np.save(latents_dir + '/{}_latent_samples'.format(exp_name), zs)
    model = load_inception_model(path)
    model.to_gpu(gpu_list[0])
    return calc_inception_score(ims, model, splits, dst, exp_name=exp_name)


def eval_inception_with_zs(gen, dis, config, n_images, dst, gpu_list, splits=10, batchsize=250, path=None, exp_name=''):
    """Re-render saved Langevin latent snapshots and score each with the Inception Score.

    Loads ``{dst}/latents/{exp_name}_latent_samples.npy``, indexed as
    ``zs[snapshot, sample, :]``. For every snapshot it renders the first
    ``n_images`` latents on ``gpu_list[0]`` (test mode, no backprop) and
    computes the score. Each snapshot writes its own stats file, suffixed
    ``{exp_name}_zstep{10 * snapshot}``; previously all snapshots overwrote
    ``inception_score_.txt``.

    ``dis`` and ``config`` are unused; they keep the signature aligned with
    ``eval_inception``.

    Returns:
        list of the per-snapshot dicts from ``calc_inception_score``.
    """
    save_path = '{}/latents/{}_latent_samples.npy'.format(dst, exp_name)
    zs = np.load(save_path)
    model = load_inception_model(path)
    model.to_gpu(gpu_list[0])
    gen.to_gpu(gpu_list[0])
    xp = gen.xp
    results = []
    for z_iter in range(zs.shape[0]):
        ims = []
        # Render in test mode without building a graph, like langevin_sample's steps.
        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
            for batch_idx in range(0, n_images, batchsize):
                z_batch = xp.asarray(zs[z_iter, batch_idx:batch_idx + batchsize, :])
                # len(z_batch) rather than batchsize: the last slice may be shorter.
                x = chainer.cuda.to_cpu(gen(len(z_batch), z_batch).data)
                ims.append(np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8))
        ims = np.concatenate(ims, axis=0)[:n_images].astype("f")
        print(ims.shape)
        eval_ret = calc_inception_score(ims, model, splits, dst,
                                        exp_name='{}_zstep{}'.format(exp_name, z_iter * 10))
        print('z_step', z_iter * 10, eval_ret)
        results.append(eval_ret)
    return results

def eval_fid_with_zs(gen, dis, config, n_images, dst, gpu_list, splits=10, batchsize=250, path=None, exp_name='',
                     stat_file='pretrained_models/cifar10/unconditional/cifar-10-fid.npz',
                     fid_n=5000, n_fid_chunks=1):
    """Report FID for each saved Langevin latent snapshot.

    Loads ``{dst}/latents/{exp_name}_latent_samples.npy``, indexed as
    ``zs[snapshot, sample, :]``. For every snapshot, latents are rendered on
    ``gpu_list[0]`` in test mode, and FID is computed against the reference
    'mean'/'cov' in ``stat_file`` on consecutive chunks of ``fid_n`` images.

    Only the first ``n_fid_chunks`` chunks are scored; the previous code
    stopped after the first chunk. ``None`` scores every chunk of
    ``n_images``. Only the images actually scored are rendered, whereas the
    previous code rendered all ``n_images`` and discarded most of them.

    The former TensorFlow branch was hard-coded off and referenced names
    whose imports are commented out, so it has been removed. ``dis``,
    ``config`` and ``splits`` are unused; they keep the signature aligned
    with ``eval_inception``.

    Returns:
        list of (mean FID, list of per-chunk FIDs), one entry per snapshot.
    """
    model = load_inception_model(path)
    model.to_gpu(gpu_list[0])
    gen.to_gpu(gpu_list[0])
    xp = gen.xp
    zs = np.load('{}/latents/{}_latent_samples.npy'.format(dst, exp_name))
    # Read the reference statistics once and close the archive; indexing an
    # NpzFile re-reads the member from disk on every access.
    with np.load(stat_file) as stat:
        ref_mean, ref_cov = stat["mean"], stat["cov"]

    if n_fid_chunks is None:
        n_render = n_images
    else:
        n_render = min(n_images, fid_n * n_fid_chunks)

    results = []
    for z_iter in range(zs.shape[0]):
        ims = []
        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
            for batch_idx in range(0, n_render, batchsize):
                z_batch = xp.asarray(zs[z_iter, batch_idx:batch_idx + batchsize, :])
                x = chainer.cuda.to_cpu(gen(len(z_batch), z_batch).data)
                ims.append(np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8))
        ims = np.concatenate(ims, axis=0)[:n_render].astype("f")

        fids = []
        for k in range(0, n_render, fid_n):
            with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
                mean, cov = get_mean_cov_chainer(model, ims[k:k + fid_n], batch_size=batchsize)
            fids.append(FID(ref_mean, ref_cov, mean, cov))

        print('z_step', z_iter * 10, np.mean(fids), fids)
        results.append((np.mean(fids), fids))
    return results
def main():
    """Command-line entry point: load models, then visualize and/or evaluate samples.

    Modes (combinable; run in this order):
        --vis       save a rows x cols grid of Langevin samples to
                    {results_dir}/{exp_name}.png, seeding NumPy with --seed
        --eval      Inception Score of Langevin samples, or of plain generator
                    samples when --baseline is given
        --eval_zs   Inception Score per saved latent snapshot
        --eval_fid  FID per saved latent snapshot
    Each eval mode uses 5000 * splits images.
    """
    global _RUN_BASELINE

    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/base.yml', help='path to config file')
    parser.add_argument('--gpu', type=str, default='0', help='index of gpu to be used')
    parser.add_argument('--data_dir', type=str, default='./data/imagenet')
    parser.add_argument('--results_dir', type=str, default='./results/gans',
                        help='directory to save the results to')
    parser.add_argument('--inception_model_path', type=str, default='./datasets/inception_model',
                        help='path to the inception model')
    parser.add_argument('--snapshot', type=str, default='',
                        help='path to the snapshot')
    parser.add_argument('--loaderjob', type=int,
                        help='number of parallel data loading processes')
    parser.add_argument('--gen_ckpt', type=str,
                        help='path to the saved generator snapshot model file to load')
    parser.add_argument('--dis_ckpt', type=str,
                        help='path to the saved discriminator snapshot model file to load')
    parser.add_argument('--exp_name', type=str,
                        help='name of the experiment')
    parser.add_argument('--splits', type=int, default=10)
    parser.add_argument('--rows', type=int, default=10)
    parser.add_argument('--cols', type=int, default=10)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--vis', action='store_true')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--eval_zs', action='store_true')
    parser.add_argument('--eval_fid', action='store_true')
    parser.add_argument('--baseline', action='store_true')

    args = parser.parse_args()

    # --baseline used to be parsed and then ignored; it now enables the
    # no-Langevin branch of eval_inception.
    if args.baseline:
        _RUN_BASELINE = True

    # yaml.load without an explicit Loader is deprecated since PyYAML 5.1 and a
    # TypeError in 6.x. Assumes the config is plain YAML (no python/* tags),
    # so safe_load suffices. The with-block also closes the file.
    with open(args.config_path) as f:
        config = yaml_utils.Config(yaml.safe_load(f))

    gpus = list(map(int, args.gpu.split(',')))
    chainer.cuda.get_device_from_id(gpus[0]).use()

    # Models
    gen, dis = load_models(config)
    chainer.serializers.load_npz(args.gen_ckpt, gen)
    chainer.serializers.load_npz(args.dis_ckpt, dis)

    out = args.results_dir
    save_path = os.path.join(out, '{}.png'.format(args.exp_name))
    if not os.path.exists(out):
        os.makedirs(out)

    if args.vis:
        rows, cols = args.rows, args.cols
        # --seed used to be read but never applied; seed the way
        # langevin_sample_vis does.
        np.random.seed(args.seed)
        n_vis = rows * cols
        # langevin_sample returns (images, latent snapshots). The old code read
        # .shape on the tuple and always crashed here.
        x, _ = langevin_sample(gen, dis, config, n_vis, batchsize=n_vis)
        _, _, h, w = x.shape
        x = x.reshape((rows, cols, 3, h, w))
        x = x.transpose(0, 3, 1, 4, 2)  # -> (rows, h, cols, w, 3)
        x = x.reshape((rows * h, cols * w, 3))
        Image.fromarray(x).save(save_path)

    n_images = int(5000 * args.splits)
    eval_kwargs = dict(dst=out, gpu_list=gpus, splits=args.splits,
                       path=args.inception_model_path, exp_name=args.exp_name)
    if args.eval:
        ret = eval_inception(gen, dis, config, n_images, **eval_kwargs)
        print(ret)
    if args.eval_zs:
        eval_inception_with_zs(gen, dis, config, n_images, **eval_kwargs)
    if args.eval_fid:
        eval_fid_with_zs(gen, dis, config, n_images, **eval_kwargs)


if __name__ == '__main__':
    main()